import torch
import math
import numpy as np
from bgflow.utils import (
    remove_mean,
    IndexBatchIterator,
)
from bgflow import (
    DiffEqFlow,
    BoltzmannGenerator,
    MeanFreeNormalDistribution,
    BlackBoxDynamics,
)
import ot as pot
from eq_ot_flow.models import EGNN_dynamics_AD2_cat
from eq_ot_flow.estimator import BruteForceEstimatorFast

from bgmol.datasets import AImplicitUnconstrained
import mdtraj as md
from scipy.stats import vonmises
import torch
import numpy as np
import os
import click
import wandb
from bgflow.utils import (
    IndexBatchIterator,
)
from bgflow import (
    DiffEqFlow,
    BoltzmannGenerator,
    MeanFreeNormalDistribution,
    BlackBoxDynamics,
)
import ot as pot
from eq_ot_flow.estimator import BruteForceEstimatorFast
import json

from path_grad_helpers import (
    HutchinsonEstimatorDifferentiable,
    AugmentedAdjointDyn,
    path_gradient,
    train_loop,
    device,
    load_weights,
    fm_train_step_ot,
)


@click.command()
@click.option("--n_batch", default=1024)
@click.option("--n_epochs", default=1001)
@click.option("--lr", default=5e-4)
@click.option("--n_knots_hutch", default=20)
@click.option(
    "--training_kind",
    default="fm",
    type=click.Choice(["fm", "eq-fm", "path"]),
    help="How do you want to train the model?",
)
@click.option("--data_path", default="/data")
@click.option("--chkpt_path", default=None)
@click.option("--force_clipping", default=False)
@click.option("--grad_clipping", default=True)
@click.option("--use_xtb", default=True)
@click.option("--reweight_data", default=False)
@click.option("--transferable", default=False)
@click.option("--grad_acc_steps", default=1)
def main(
    n_report_steps=5,
    n_batch=1024,
    n_epochs=1001,
    n_holdout=500000,
    lr=5e-4,
    n_knots_hutch=20,
    training_kind="fm",
    data_path="/data",
    chkpt_path=None,
    # NOTE: when invoked through click the CLI default (False) is what applies.
    force_clipping=True,
    grad_clipping=True,
    use_xtb=True,
    reweight_data=False,
    transferable=False,
    grad_acc_steps=1,
):
    """Train a Boltzmann generator on the AD2 data and save a checkpoint.

    Loads the configurations found under ``data_path``, builds a
    ``DiffEqFlow`` driven by ``EGNN_dynamics_AD2_cat`` and trains it either
    with ``fm_train_step_ot`` (``training_kind="fm"``) or with
    ``path_gradient`` (every other choice). The run configuration is written
    to ``config.json`` and the final weights to ``chkpt.pt`` inside a newly
    created log folder.

    ``n_report_steps`` and ``n_holdout`` are not exposed as command line
    options; inside this function they are only recorded in the run
    configuration.
    """
    # Clamp once so the printed/logged batch size and the training loop agree.
    grad_acc_steps = max(1, grad_acc_steps)
    print(f"Batch-size {n_batch * grad_acc_steps}")
    dataset = AImplicitUnconstrained(read=True)

    n_particles = 22
    n_dimensions = 3
    dim = n_particles * n_dimensions
    # Concentration parameter of the von Mises density used for reweighting.
    kappa = 10
    # Factor applied to the training coordinates; the energy wrapper below
    # divides it out again before evaluating the target.
    scaling = 10

    target = dataset.get_energy_model()
    # TODO: the scaling could live inside the flow instead, so that the data
    # stay untouched and the scaling becomes part of the model.
    unscaled_energy = target.energy
    target.energy = lambda x: unscaled_energy(x / scaling)

    ala_traj = md.Trajectory(dataset.xyz, dataset.system.mdtraj_topology)

    # Initialize either the transferable or the non-transferable BG
    if transferable:
        # atom types for tbg: one type per atom, with three groups of atoms
        # sharing a type each
        atom_types = np.arange(22)
        atom_types[[1, 2, 3]] = 2
        atom_types[[19, 20, 21]] = 20
        atom_types[[11, 12, 13]] = 12
    else:
        # Type every atom by the first letter of its name in the topology.
        atom_dict = {"H": 0, "C": 1, "N": 2, "O": 3}
        atom_types = np.array(
            [atom_dict[atom.name[0]] for atom in ala_traj.topology.atoms]
        )

        # Make the backbone atoms distinguishable
        atom_types[[4, 6, 8, 14, 16]] = np.arange(4, 9)

    h_initial = torch.nn.functional.one_hot(torch.tensor(atom_types))

    data_file = "AD2_relaxed.npy" if use_xtb else "AD2_classical.npy"
    data_smaller = torch.from_numpy(np.load(f"{data_path}/{data_file}")).float() / 10
    data_smaller = (
        remove_mean(data_smaller, n_particles, n_dimensions).reshape(-1, dim) * scaling
    )

    # Reweighting: resample the data set (with replacement) using weights that
    # depend on the phi dihedral. The trajectory and the dihedrals are only
    # needed here, so they are not computed when reweighting is disabled.
    if reweight_data:
        ala_traj = md.Trajectory(
            data_smaller.cpu().numpy().reshape(-1, n_particles, n_dimensions),
            dataset.system.mdtraj_topology,
        )
        phi = md.compute_phi(ala_traj)[1].flatten()
        weights = 150 * vonmises.pdf(phi - 1.0, kappa) + 1

        data_smaller_weighted_idx = np.random.choice(
            np.arange(len(data_smaller)),
            len(data_smaller),
            p=weights / weights.sum(),
            replace=True,
        )
        data_smaller = data_smaller[data_smaller_weighted_idx]

    config = {
        "n_report_steps": n_report_steps,
        "n_batch": n_batch,
        "n_epochs": n_epochs,
        "n_holdout": n_holdout,
        "lr": lr,
        "n_knots_hutch": n_knots_hutch,
        "training_kind": training_kind,
        "chkpt_path": chkpt_path,
        "force_clipping": force_clipping,
        "grad_clipping": grad_clipping,
        "use_xtb": use_xtb,
        "reweight_data": reweight_data,
        "transferable": transferable,
        "grad_acc_steps": grad_acc_steps,
        "full batch size": n_batch * grad_acc_steps,
    }

    wandb.init(project="PathGradFlowMatching-AD2", config=config)

    if chkpt_path is not None:
        # Strips the last three characters of the checkpoint path
        # (presumably the ".pt" suffix -- verify against how checkpoints are named).
        log_folder = f"{chkpt_path[:-3]}-{'T' if transferable else ''}BG-{n_batch}x{grad_acc_steps}-{training_kind}-{n_epochs}-{lr}-{n_knots_hutch}-Clip{grad_clipping}{force_clipping}-XTB{use_xtb}-weight{reweight_data}"
    else:
        log_folder = f"models/AD2V3-{'T' if transferable else ''}BG-{n_batch}-{grad_acc_steps}-{training_kind}-{n_epochs}-{lr}-knots{n_knots_hutch}-Clip{grad_clipping}{force_clipping}-XTB{use_xtb}-weight{reweight_data}"

    print(f"Creating folder {log_folder}")
    # makedirs also creates a missing parent (e.g. "models/"); it still raises
    # if the run folder itself already exists, so earlier runs are not overwritten.
    os.makedirs(log_folder)
    with open(f"{log_folder}/config.json", "w") as config_file:
        json.dump(config, config_file)

    print("Loading data")
    print(f"Dataset size {data_smaller.shape}")

    # now set up a prior
    prior = MeanFreeNormalDistribution(dim, n_particles, two_event_dims=False).to(
        device
    )

    print("Building Flow")
    # Build the Boltzmann Generator
    net_dynamics = EGNN_dynamics_AD2_cat(
        n_particles=n_particles,
        device=device,
        n_dimension=dim // n_particles,
        h_initial=h_initial,
        hidden_nf=64,
        act_fn=torch.nn.SiLU(),
        n_layers=5,
        recurrent=True,
        tanh=True,
        attention=True,
        condition_time=True,
        mode="egnn_dynamics",
        agg="sum",
    )

    bb_dynamics = BlackBoxDynamics(
        dynamics_function=net_dynamics, divergence_estimator=BruteForceEstimatorFast()
    )
    flow = DiffEqFlow(dynamics=bb_dynamics)
    # having a flow and a prior, we can now define a Boltzmann Generator

    bg = BoltzmannGenerator(prior, flow, target.to(device))
    if chkpt_path is not None:
        load_weights(bg, chkpt_path)

    print("Setting up training loop")
    batch_iter = IndexBatchIterator(len(data_smaller), n_batch)

    optim = torch.optim.Adam(bg.parameters(), lr=lr)

    sigma = 0.01

    def batches():
        # Yield one batch of configurations per index set, moved to the device.
        for idxs in batch_iter:
            yield data_smaller[idxs].to(device)

    if training_kind == "fm":
        print("Setting up fm trainer")
        fm_trainer = lambda x1: fm_train_step_ot(x1, prior, bg, pot, sigma=sigma)
        batch_it = batches
    else:
        # NOTE(review): "eq-fm" also lands here and is set up exactly like
        # "path" -- confirm this is intended.
        print("Setting up Path grads")
        path_grad_dynamics = AugmentedAdjointDyn(
            BlackBoxDynamics(
                dynamics_function=net_dynamics,
                divergence_estimator=HutchinsonEstimatorDifferentiable(),
            )
        )
        flow_hutch = DiffEqFlow(
            dynamics=path_grad_dynamics,
            integrator="rk4",
            n_time_steps=n_knots_hutch,
        )

        # Here bg is not used, since for training we use a different integrator/adjoint ode
        fm_trainer = lambda x1: path_gradient(
            x1,
            prior,
            target,
            flow_hutch,
            x_includes_grads=use_xtb,
            force_clipping=force_clipping,
        )

        if use_xtb:
            data_forces = torch.from_numpy(
                np.load(f"{data_path}/AD2_relaxed_forces.npy")
            ).float()
            if reweight_data:
                # Keep the forces aligned with the resampled configurations.
                data_forces = data_forces[data_smaller_weighted_idx]

            def both_batches():
                # Yield (configurations, forces) pairs for the same indices.
                for idxs in batch_iter:
                    yield data_smaller[idxs].to(device), data_forces[idxs].to(device)

            batch_it = both_batches
        else:
            batch_it = batches
    print("Starting training")

    train_loop(
        n_epochs,
        bg,
        fm_trainer,
        batch_it,
        optim,
        grad_clipping=grad_clipping,
        grad_acc_steps=grad_acc_steps,
        batches_per_epoch=math.ceil(len(data_smaller) / n_batch),
    )
    print("Finished training")
    torch.save(bg.state_dict(), f"{log_folder}/chkpt.pt")
    print("Saved and finished")


# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
